{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c1d31f54",
   "metadata": {},
   "source": [
    "# Notebook for preprocessing Wikipedia (English) dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21eb8ed4",
   "metadata": {},
   "source": [
    "### Initilizing phonemizer and tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6ca5ee4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "config_path = \"Configs/config.yml\" # you can change it to anything else\n",
    "config = yaml.safe_load(open(config_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b52b79ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from phonemize import phonemize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b363b0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import phonemizer\n",
    "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True,  with_stress=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92d58c92",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TransfoXLTokenizer\n",
    "tokenizer = TransfoXLTokenizer.from_pretrained(config['dataset_params']['tokenizer']) # you can use any other tokenizers if you want to"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2eb25417",
   "metadata": {},
   "source": [
    "### Process dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25e5ae16",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "dataset = load_dataset(\"wikipedia\", \"20220301.en\")['train'] # you can use other version of this dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca7ca2f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "root_directory = \"./wiki_phoneme\" # set up root directory for multiprocessor processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92a578d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "num_shards = 50000\n",
    "\n",
    "def process_shard(i):\n",
    "    directory = root_directory + \"/shard_\" + str(i)\n",
    "    if os.path.exists(directory):\n",
    "        print(\"Shard %d already exists!\" % i)\n",
    "        return\n",
    "    print('Processing shard %d ...' % i)\n",
    "    shard = dataset.shard(num_shards=num_shards, index=i)\n",
    "    processed_dataset = shard.map(lambda t: phonemize(t['text'], global_phonemizer, tokenizer), remove_columns=['text'])\n",
    "    if not os.path.exists(directory):\n",
    "        os.makedirs(directory)\n",
    "    processed_dataset.save_to_disk(directory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d73caf6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pebble import ProcessPool\n",
    "from concurrent.futures import TimeoutError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c21f9dcf",
   "metadata": {},
   "source": [
    "#### Note: You will need to run the following cell multiple times to process all shards because some will fail. Depending on how fast you process each shard, you will need to change the timeout to a longer value to make more shards processed before being killed.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04261364",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "max_workers = 32 # change this to the number of CPU cores your machine has \n",
    "\n",
    "with ProcessPool(max_workers=max_workers) as pool:\n",
    "    pool.map(process_shard, range(num_shards), timeout=60)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b78caee6",
   "metadata": {},
   "source": [
    "### Collect all shards to form the processed dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0568da38",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_from_disk, concatenate_datasets\n",
    "\n",
    "output = [dI for dI in os.listdir(root_directory) if os.path.isdir(os.path.join(root_directory,dI))]\n",
    "datasets = []\n",
    "for o in output:\n",
    "    directory = root_directory + \"/\" + o\n",
    "    try:\n",
    "        shard = load_from_disk(directory)\n",
    "        datasets.append(shard)\n",
    "        print(\"%s loaded\" % o)\n",
    "    except:\n",
    "        continue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1547f2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = concatenate_datasets(datasets)\n",
    "dataset.save_to_disk(config['data_folder'])\n",
    "print('Dataset saved to %s' % config['data_folder'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce886d15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# check the dataset size\n",
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdf6f6f6",
   "metadata": {},
   "source": [
    "### Remove unneccessary tokens from the pre-trained tokenizer\n",
    "The pre-trained tokenizer contains a lot of tokens that are not used in our dataset, so we need to remove these tokens. We also want to predict the word in lower cases because cases do not matter that much for TTS. Pruning the tokenizer is much faster than training a new tokenizer from scratch. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28cec407",
   "metadata": {},
   "outputs": [],
   "source": [
    "from simple_loader import FilePathDataset, build_dataloader\n",
    "\n",
    "file_data = FilePathDataset(dataset)\n",
    "loader = build_dataloader(file_data, num_workers=32, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b7504eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "special_token = config['dataset_params']['word_separator']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fcb44a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get all unique tokens in the entire dataset\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "unique_index = [special_token]\n",
    "for _, batch in enumerate(tqdm(loader)):\n",
    "    unique_index.extend(batch)\n",
    "    unique_index = list(set(unique_index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1445662d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get each token's lower case\n",
    "\n",
    "lower_tokens = []\n",
    "for t in tqdm(unique_index):\n",
    "    word = tokenizer.decode([t])\n",
    "    if word.lower() != word:\n",
    "        t = tokenizer.encode([word.lower()])[0]\n",
    "        lower_tokens.append(t)\n",
    "    else:\n",
    "        lower_tokens.append(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa2dea92",
   "metadata": {},
   "outputs": [],
   "source": [
    "lower_tokens = (list(set(lower_tokens)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a76cda1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# redo the mapping for lower number of tokens\n",
    "\n",
    "token_maps = {}\n",
    "for t in tqdm(unique_index):\n",
    "    word = tokenizer.decode([t])\n",
    "    word = word.lower()\n",
    "    new_t = tokenizer.encode([word.lower()])[0]\n",
    "    token_maps[t] = {'word': word, 'token': lower_tokens.index(new_t)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1c94be2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "with open(config['dataset_params']['token_maps'], 'wb') as handle:\n",
    "    pickle.dump(token_maps, handle)\n",
    "print('Token mapper saved to %s' % config['dataset_params']['token_maps'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c9e968e",
   "metadata": {},
   "source": [
    "### Test the dataset with dataloader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9025e2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataloader import build_dataloader\n",
    "\n",
    "train_loader = build_dataloader(dataset, batch_size=32, num_workers=0, dataset_config=config['dataset_params'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70874215",
   "metadata": {},
   "outputs": [],
   "source": [
    "_, (words, labels, phonemes, input_lengths, masked_indices) = next(enumerate(train_loader))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "BERT",
   "language": "python",
   "name": "bert"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
